
{--
    Java code generation for the frege compiler

    This package deals with pattern matching and @case@ statements.
 -}


package frege.compiler.gen.java.Match where


import frege.Prelude hiding(apply, <+>)
import Data.TreeMap as TM(TreeMap, values, keys, each, insert, lookup)
import Data.List as DL(sortBy, partitioned)

import  Compiler.enums.Literals

import  Compiler.types.Strictness

-- import  Compiler.types.JNames
import  Compiler.types.QNames
import  Compiler.types.Types
import  Compiler.types.Patterns
import  Compiler.types.Expression
import  Compiler.types.ConstructorField
import  Compiler.types.Symbols
import  Compiler.types.Global as G
import  Compiler.types.Positions
import  Compiler.types.AbstractJava

import  Compiler.common.Mangle
import  Compiler.common.Errors as E()
import  Compiler.common.Types as CT
import  Compiler.common.JavaName


import  Compiler.classes.Nice

import  Compiler.enums.Flags(TRACEG)

import frege.compiler.Utilities     as U() 
import frege.lib.PP(text, <>, <+>, <+/>, </>)  
import Compiler.common.Trans as T(patternStrictness)
import frege.compiler.Typecheck     as TY()

import  Compiler.gen.java.Common     as GU(lambdaType, sComment, namedFields, strict, boxed,
                                            tauJT, substJT,
                                            sigmaJT, jstFind, argTypeB, asKinded, fromKinded) 
import  Compiler.gen.java.Bindings   as GB
import  Compiler.gen.java.Constants(staticConst)


infixr 6 `<>`

{--
  * [usage] @match assert pattern bind continuation bindings@
  * [returns] a pair of the binding the value was matched at (possibly made
  *           strict and/or realized) and the list of generated java statements
  * Generate code and/or extend the current @bindings@ so that they perform and/or
  * reflect a match of @pattern@ against the java expression in @bind@, and
  * generate the code for a successful match by applying @continuation@ to the
  * extended bindings.
  *
  * Generated code will look like this:
  * > comment
  * > auxiliary local definitions
  * > if (patternmatches) {
  * >           code generated by continuation
  * > }
  * If the @assert@ switch is on, this is the last possible match and
  * an assert statement is generated instead of the if. This is useful in
  * code like:
  * > case foo of
  * >   Just bar -> ...
  * >   Nothing -> ...
  * The @Nothing@ alternative needs no runtime test if @bar@ was irrefutable:
  * when the @Just bar@ alternative did not match, the value must be @Nothing@.
  *
  * There does not have to be an @if@ or any other code at all, for example when
  * the pattern is irrefutable. The code generated by the continuation *must* return:
  * if control flow reaches the closing brace of the @if@, this indicates
  * that the pattern match failed.
  *
  * Patterns not handled by any of the specific equations produce a single
  * 'JError' statement that describes the failed match.
  -}
match :: Bool 
            -> Pattern 
            -> Binding 
            -> (TreeMap Int Binding -> StG [JStmt]) 
            -> TreeMap Int Binding 
            -> StG (Binding, [JStmt])
-- Variable pattern: always matches, so no test is generated and @assert@ is irrelevant.
-- A variable that is strict according to its strictness signature is made strict
-- and passed to 'realize' together with the variable's java name (presumably
-- emitting a local declaration); a lazy variable keeps its lazy expression.
-- The resulting binding is recorded under the variable's uid before the
-- continuation runs.
match assert (PVar {pos,uid,var}) bind cont binds = do
         vsym <- U.findV local
         g    ←  getST

         let strict = vsym.strsig.isStrict
             sls    = if strict then "strict" else "lazy"

         E.logmsg TRACEG pos (text "match" <+> text sls <+> text "var `" <+> text var 
                <+> text "´ bound to " 
                <+> text (show bind)
            )

        -- make sure a lazy variable is never realized with, e.g., arg$1.call
        -- i.e. if the expression is a no-argument call of a method named "call",
        -- drop the call and use the receiver expression, marking its java type as Lazy
         let unCall bnd
                | JInvoke ex [] ← bnd.jex,
                  JExMem{jex=arg, name="call"} ← ex
                = bnd.{jtype ← Lazy, jex = arg}
                | otherwise = bnd

         let sbind = if strict then strictBind g bind else unCall bind
         -- only strict variables are realized; lazy ones generate no code here
         (rbind, code) <- if strict then realize (jname g) sbind else pure (sbind, [])
         -- the wildcard variable "_" gets no explanatory comment statement
         let stmt
                    | var == "_" = code
                    | otherwise  = sComment ("bind " ++ sls ++ " var " ++ nice (Symbol.name vsym) g
                                         ++ "  to  " ++ show rbind) : code
             nbinds = insert uid rbind binds
         rest <- cont nbinds
         pure (rbind, stmt++rest)
     where
         local = Local uid var
         -- java name of the variable, used as the name of the realized local
         jname g = (javaName g local).base

 
match assert (p@PAt {pat,uid,var}) bind cont binds = do
        g    <- getST
        ps   <- T.patternStrictness pat
        vsym <- U.findV (Local uid var)
        let target   = (javaName g (Local uid var)).base
            header b = sComment ("match " 
                        ++ nice p g ++ "::" ++ nicer vsym.typ g 
                        ++ "  with  " ++ show b)
            -- An as-pattern whose variable is strict, or whose sub pattern is
            -- strict anyway, needs the value in strict form before it is bound;
            -- otherwise the binding is realized as it is.
            mustEval = Strictness.isStrict vsym.strsig || ps.isStrict
            start    = if mustEval then strictBind g bind else bind
        (rbind, realized) <- realize target start
        (xbind, inner)    <- match assert pat rbind cont (insert uid rbind binds)
        stio (xbind, header rbind : (realized ++ inner))
        
 
match assert (pat@PLit {kind=LBool, value}) bind cont binds = do
     g    <- getST
     body <- cont binds
     -- a boolean literal is tested directly on the strict expression,
     -- negated when the literal is false
     let strictEx = (strictBind g bind).jex
         test     = if value == "true" then strictEx else JUnop "!" strictEx
         note     = sComment ("match  " ++ nice pat g ++ "  with  " ++ show bind)
         guarded
             | assert    = JAssert test : body
             | otherwise = [JCond "if" test body]
     stio (bind, note : guarded)
 
-- Constructor pattern: dispatch on the kind of the constructor's type.
--  * enum types: compare the strict value with the constructor's static member using ==
--  * newtypes: match the single sub pattern against the same java expression,
--    re-typed at the instantiated argument type
--  * other products and variants: matchVariant (matchProd currently delegates to it)
match assert (pat@PCon {pos,qname,pats}) bind  cont binds = do
         -- g <- getST
         symd <- U.findD qname                   -- forall a.a -> List a -> List a
         symt <- U.findT symd.name.tynm          -- forall a.List a 
         if symt.enum then matchEnum symd symt
             else if symt.product
                 then if symt.newt
                     then matchNew     symd symt 
                     else matchProd    symd symt -- pat bind cont binds
                 else matchVariant symd symt -- pat bind cont binds
     where
         -- strict form of the binding; a Kinded java type is passed to 'adapt'
         -- with its un-kinded form (via fromKinded)
         unKindedStrict g lbnd = case strictBind g lbnd of
                    kbnd ->  case kbnd.jtype of
                        Kinded{} → adapt g kbnd (fromKinded kbnd.jtype)
                        other    → kbnd
         comment g = sComment ("match  " ++ nice pat g
                                 ++ "  with  " ++ show bind)
         -- matchNewt :: Symbol -> Symbol -> StG (Binding, [JStmt])
         -- matchNewt symd symt = match (head pats) bind cont binds
         -- enum constructor: test  value == StaticMember  (or assert it)
         matchEnum :: Symbol -> Symbol -> StG (Binding, [JStmt])
         matchEnum symd symt = do
             g <- getST
             let sbnd = unKindedStrict g bind
             -- (bind, code1) <- realize "$" sbnd
             body  <- cont binds
             let comp = JBin sbnd.jex "==" (JX.staticMember (symJavaName g symd)) 
                 ifc  = if assert then JAssert comp : body else [JCond "if" comp body]
             stio (sbnd, comment g : ifc)
 
         -- newtype constructor: no test and no code of its own; the sub pattern
         -- is matched against the same java expression and java type, but with
         -- the frege type of the constructor argument instantiated at our type
         matchNew :: Symbol -> Symbol -> StG (Binding, [JStmt])
         matchNew symd symt = do
            g <- getST
            let -- box0    = adaptSigma g bind
                arg     = symd.typ.rho.sigma                  -- first arg of data con
                tree    = unifySigma g symt.typ bind.ftype    -- instantiate type args a -> Int
                sig     = substSigma tree arg                 -- substitute in arg
            E.logmsg TRACEG (getpos pat) (
                    text "matchNew:" <+> text (nicer pat g) </> PP.nest 4 (
                            text "arg::" <+> text (nice arg g)
                        </> text "bnd::" <+> text (nice bind.ftype g)
                        </> text "sig::" <+> text (nice sig g) 
                    )
                )
                -- pretend our binding is the pattern
            let  box1 = (newBind g sig bind.jex).{jtype = bind.jtype}
            match assert (head pats) box1 cont binds

         -- general constructor: realize the strict value; for a non-product type
         -- also realize the variant obtained via the conGetter method and test it
         -- for null; then match the sub patterns against the constructor fields
         matchVariant :: Symbol -> Symbol -> StG (Binding, [JStmt])
         matchVariant symd symt = do
            g <- getST
            E.logmsg TRACEG (getpos pat) (text "match pattern "
                <+> text (nicer pat g)
                <+> text " with "
                <+> (text . show) bind) 
            let box1 = unKindedStrict g bind                  -- List Int
                tree = unifySigma g symt.typ bind.ftype       -- a -> Int 
                rho  = substRho tree symd.typ.rho             -- Int -> List Int -> List Int
            (boxd, code1) <- realize "$" box1                 -- TList $1 = .....
 
            let  -- smode = any Strictness.isStrict pss
                 cname = if symt.product then "" else conGetter qname   -- e.g. asDCons
                 vbind = if symt.product then boxd else
                            Bind boxd.stype boxd.ftype
                                (variantType g boxd.jtype symd) 
                                (JInvoke (JX.jexmem boxd.jex cname) [])
 
            (varb, code2) <- if symt.product then return (boxd, [])
                             else realize "$" vbind             -- TList.DCons $2 = $1.asDCons()

            E.logmsg TRACEG pos (text "match constructor "
                    <+> text (nicer symd g)
                    </> text "realized at "
                    <+> text (show varb)
                    </> text "fields:"
                    <+> PP.stack (map (text . flip nicer g . _.typ) symd.flds)
                )



            let sigs = snd $ U.returnType symd.typ.rho               -- [b, List b]
                -- set up the expressions that are to be matched by sub patterns
                obinds = zipWith (fldBind g varb) (namedFields symd.flds) sigs
                rsigs = snd $ U.returnType rho
                jtree = fmap (boxed . tauJT g) tree
                -- re-type each field binding at the instantiated argument type
                subst bind sig = bind.{ftype=sig, jtype ← substJT jtree}
                pbinds = zipWith subst obinds rsigs


            -- make sure refutable patterns are matched first so that
            -- evaluation of lazy values that are bound to variables does
            -- not occur before it is sure that the overall match succeeds
            let (zpats, zbinds) = (unzip • reverse • sortBy (comparing (T.patternRefutable g • fst))) 
                            (zip pats pbinds)
            rest <- matches assert zpats zbinds cont binds     
 
            -- product types have a single constructor and need no null test
            let notnull = JBin varb.jex "!=" (JAtom "null")
                ifn = if symt.product
                        then rest
                        else if assert  then JAssert notnull : rest 
                                        else [JCond "if" notnull rest]

            stio (boxd, (comment g : code1) ++ code2 ++ ifn)
 
         matchProd :: Symbol -> Symbol -> StG (Binding, [JStmt])
         matchProd symd symt = matchVariant symd symt   -- for the time being
         
 
match assert (pat@PLit {kind=LString, value}) bind cont binds = do
     g            <- getST
     (sbnd, code) <- realize "$" (strictBind g bind)
     body         <- cont binds
     -- string literals are compared with  "literal".equals(value)
     let isEqual = JInvoke (JX.jexmem (JAtom value) "equals") [sbnd.jex]
         note    = sComment ("match  " ++ nice pat g ++ "  with  " ++ show bind)
         guarded
             | assert    = JAssert isEqual : body
             | otherwise = [JCond "if" isEqual body]
     stio (sbnd, note : (code ++ guarded))
 
-- Non-boolean, non-string literal patterns. In every case the value is made
-- strict and realized into "$", and the generated test is either wrapped
-- in an if around the continuation code or, when @assert@ is on, emitted
-- as an assert in front of it. If no guard applies, matching falls through
-- to the remaining equations.
match assert (pat@PLit {kind, value, negated}) bind cont binds
     -- plain (non-negated) primitive literals: literal == value
     | kind `elem` [LChar, LInt, LLong, LDouble, LFloat], not negated = do
         g <- getST
         (sbnd,code) <- realize "$" (strictBind g bind)
         body <- cont binds
         let comment = sComment ("match  " ++ nice pat g
                                     ++ "  with  " ++ show bind)
             jex = JBin (JAtom value) "==" sbnd.jex
             ifc = if assert then JAssert jex : body else [JCond "if" jex body]
         stio (sbnd, (comment:code) ++ ifc)
     -- Integer literals: the literal becomes a static constant
     -- (via staticConst), compared with  constant.equals(value)
     | kind == LBig = do
         g <- getST
         (sbnd, code) <- realize "$" (strictBind g bind)
         body <- cont binds
         let comment = sComment ("match  " ++ nice pat g
                                     ++ "  with  " ++ show bind)
             -- sbnd = adaptSigma g g bind
             lit = Lit {pos = pat.pos, kind = LBig, value, typ = Just TY.sigInteger, negated}
         jst <- staticConst lit
         -- only xbnd.jex (which is jst) is used below
         let xbnd =  Bind{ftype = TY.sigInteger, 
                        stype = nicer TY.sigInteger g, 
                        jtype = sigmaJT g TY.sigInteger, 
                        jex = jst}
         let jex = JInvoke (JX.jexmem xbnd.jex "equals") [sbnd.jex]
             ifc = if assert then JAssert jex : body else [JCond "if" jex body]
         stio (sbnd, (comment:code) ++ ifc)
     -- remaining numeric literals (e.g. negated ones): the literal becomes a
     -- static constant (via staticConst), compared with  constant == value
     | isLiteralNumeric kind = do
        g <- getST
        (sbnd, code) <- realize "$" (strictBind g bind)
        body ← cont binds
        let comment = sComment ("match  " ++ nice pat g
                                     ++ "  with  " ++ show bind)
            lit = Lit {pos = pat.pos, kind, value, typ = Nothing, negated}
        jst ← staticConst lit
        let jex  = JBin jst  "=="  sbnd.jex 
            ifc = if assert then JAssert jex : body else [JCond "if" jex body]
        pure (sbnd, (comment:code) ++ ifc)
     -- regex literals: the compiled regex is a static constant; the generated
     -- test is  constant.matcher(value).find()
     | kind == LRegex = do
         g <- getST
         (sbnd,code) <- realize "$" (strictBind g bind)
         body <- cont binds
 
         let comment = sComment ("match  " ++ nice pat g
                                 ++ "  with  " ++ show bind)
             -- sbnd = adaptSigma g g bind
             lit = Lit {pos = pat.pos, kind = LRegex, value, typ = Just (TY.sigRegex), negated=false}
         jst <- staticConst lit
         -- only xbnd.jex (which is jst) is used below
         let xbnd = Bind{ftype = TY.sigRegex, 
                        stype = nicer TY.sigRegex g, 
                        jtype = sigmaJT g TY.sigRegex, 
                        jex = jst}
 
         let matcher = JInvoke (JX.jexmem xbnd.jex "matcher") [sbnd.jex]
             jex = JInvoke (JX.jexmem matcher "find") []
             ifc = if assert then JAssert jex : body else [JCond "if" jex body]
         stio (sbnd, (comment:code) ++ ifc)
 
-- Regex-match pattern (a variable bound to the result of matching a regex).
-- The strict value is passed, together with the regex constant, to the
-- jstFind helper. The result is realized under the variable's java name and
-- bound to the variable's uid, and the match succeeds when that result is not null.
match assert (pat@PMat {pos, uid, var, value}) bind cont binds = do
         g <- getST
         -- NOTE(review): vsym is not used below; the lookup presumably only
         -- ensures the variable symbol exists — confirm before removing
         vsym <- U.findV (Local uid var)
         (sbnd,code) <- realize "$" (strictBind g bind)
         let mjt = sigmaJT g TY.sigMatcher
         let comment = sComment ("match  " ++ nice pat g
                                 ++ "  with  " ++ show bind)
             
             lit = Lit {pos, kind = LRegex, value, typ = Just (TY.sigRegex), negated = false}
 
         jst <- staticConst lit
         -- binding for the matcher produced by jstFind(value, regex)
         let mbnd = Bind{ftype = TY.sigMatcher, stype = nicer TY.sigMatcher g, 
                         jtype = mjt,
                         jex = JInvoke jstFind  [sbnd.jex, jst]}
         (mbnd,code2) <- realize (javaName g (Local uid var)).base mbnd
         body <- cont (binds.insert uid mbnd)
 
         -- the match succeeded if the realized matcher is not null
         let  -- jex = JInvoke (JX.jexmem mbnd.jex "find") []
              jex = JBin{j1 = mbnd.jex, op = "!=", j2 = JAtom "null"}
              ifc = if assert then JAssert jex : body else [JCond "if" jex body]
         stio (sbnd, (comment:code) ++ code2 ++ ifc)
 
-- Annotated and user-marked patterns are matched through the pattern they wrap.
match flag (PAnn {pat})  bnd k bs = match flag pat bnd k bs
match flag (PUser {pat}) bnd k bs = match flag pat bnd k bs
-- Any other pattern is not supported here: leave the binding as it is and
-- emit a single JError statement that describes the failed match.
match _ unsupported bnd _ _ = do
     g <- getST
     let msg = "match  " ++ nice unsupported g ++ "  with  " ++ show bnd
     stio (bnd, [JError msg])
 
{--
  * A variant of 'match' that matches the components of a product against a pattern
  *
  * @pat@ must be a constructor application whose constructor is the same
  * as given in @con@; otherwise a fatal compiler error is reported.
  * Refutable sub patterns are matched first, so that realization of strict
  * variables does not occur outside an @if@.
  -}
matchCon assert (PCon {pos,qname, pats}) con bexs cont binds = do
         g    <- getST
         symd <- U.findD qname
         if symd.sid == Symbol.sid con
             then do
                 let refutableFirst    = reverse • sortBy (comparing (T.patternRefutable g • fst))
                     (subpats, subexs) = unzip (refutableFirst (zip pats bexs))
                 matches assert subpats subexs cont binds
             else E.fatal pos (text ("matchCon: " ++ nice qname g ++ " against " ++ nice con g))
matchCon _ _ _ _ _ _ = error "matchCon: no constructor"

--- Match a list of subpatterns against a list of subexpressions.
--- The pairs are matched in order: the continuation of one match is the match
--- of the remaining pairs, and once both lists are exhausted the original
--- continuation @cont@ produces the code. Lists of different length are a compiler bug.
matches assert []     []     cont binds = cont binds
matches assert (p:ps) (b:bs) cont binds = fmap snd (match assert p b rest binds)
    where rest = matches assert ps bs cont
matches _ _ _ _ _ = Prelude.error "matches: cannot happen when compiler is sane"
 

{-
    tell at what 'JType' the case expression should best be computed to match pattern

patternRMode g p =
    case p of
        PVar {uid,var} -> case (Local uid var).findit g of
            Just sym -> if sym.strsig.isStrict
                 then strict (sigmaJT g sym.typ)
                 else lazy   (sigmaJT g sym.typ)
            other -> error ("patternRMode: var not found  " ++ nicer p g)
        PAt {pos,uid,var,pat} -> patternRMode g PVar{pos, uid, var}
        PUser {pat, lazy} -> if lazy 
            then GU.lazy   (patternRMode g pat) 
            else GU.strict (patternRMode g pat)
        PLit{kind} -> case kind of
                       LBool -> Nativ "boolean" []
                       LChar -> Nativ "char" []
                       LString -> jtString
                       LInt -> jtInt
                       LBig -> Nativ "java.math.BigInteger" []
                       LLong -> Nativ "long" []
                       LFloat -> Nativ "float" []
                       LDouble -> Nativ "double" []
                       LRegex -> jtString

        PMat {pos} ->jtString
        PAnn {pat} -> patternRMode g pat
        PCon {pos,qname,pats} = case QName.findit qname g of
            Just symd -> case (Symbol.name symd).tynm.findit g of
                Just symt -> sigmaJT g (Symbol.typ symt)
        PConFS {pos} -> error ("patternRMode: found PConFS") 
-}
 
{--
  * [usage] @conGetter qname@
  * [return] the name of the method that gets the variant: @"as"@ followed
  *          by the 'mangled' base name of the constructor
  * [example] @conGetter (MName tname "Con")@ evaluates to @"asCon"@
  *           (for a base name that 'mangled' leaves unchanged)
  * Any 'QName' other than a member name is a compiler error.
  -}
conGetter (MName tname base) = "as" ++ (mangled base)
conGetter _ = error "conGetter: no member"
 
 
{--
  * [usage] @variantType g jtype symd@
  * [return] the java type of the variant for constructor @symd@, i.e. @tMaybe.dJust<a>@
  *          when @jtype@ is the java type of the algebraic type: @jtype@ with its
  *          java name replaced by the java name of the constructor
  -}
variantType :: Global -> JType -> Symbol -> JType
variantType g jtype symd = jtype.{jname = symJavaName g symd}

{--
    Instantiate a field of an algebraic value at a given type.
    
     [value] the (strict) 'Binding' that holds the java expression for the value
     [field] the field in question
     [at] the type this is to be instantiated at
     
    Returns a new binding whose java expression accesses the (named) field
    of the strict value. Its java type is the argument type for @at@, made
    kinded with the arity of the field's declared java type (0 if that
    type is not 'Kinded'); a lazy argument type is boxed before that.
    -}
fldBind :: Global -> Binding -> ConField QName -> Sigma -> Binding
fldBind g value field at = Bind{stype = nicer at g, ftype = at, jtype = fieldJT, jex = access}
    where
        -- java type of the argument at the instantiated type, honoring field strictness
        argJT     = argTypeB g field.strict at
        -- kind arity of the field's declared java type
        kindArity = case sigmaJT g field.typ of
            Kinded{arity} -> arity
            _             -> 0
        fieldJT   = case argJT of
            Lazy{} -> Lazy (asKinded (boxed argJT) kindArity)
            _      -> asKinded argJT kindArity
        -- member access  value.fieldname
        access    = JExMem{jex = (strictBind g value).jex, name = unJust field.name, targs = []}